{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "e36e18bb-f797-4d5f-96a4-3d10350f93cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "os.environ['OMP_NUM_THREADS'] = '1'\n",
    "import argparse\n",
    "import torch\n",
    "from env import create_train_env\n",
    "from model import ActorCritic\n",
    "import torch.nn.functional as F"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "c0326d9a-420f-4097-87cf-4ae5bfe84d61",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_args(argv=None):\n",
    "    \"\"\"Build the option namespace used to evaluate a trained agent.\n",
    "\n",
    "    Args:\n",
    "        argv: Optional sequence of command-line style tokens, for example\n",
    "            ``['--world', '2', '--stage', '3']``. When omitted, no tokens are\n",
    "            parsed and every option keeps its default. ``sys.argv`` is never\n",
    "            consulted, because inside a notebook it holds the kernel's own\n",
    "            launch arguments rather than user options.\n",
    "\n",
    "    Returns:\n",
    "        argparse.Namespace with ``world``, ``stage``, ``action_type``,\n",
    "        ``saved_path`` and ``output_path`` attributes.\n",
    "    \"\"\"\n",
    "    # The text must be passed as ``description``; the first positional\n",
    "    # parameter of ArgumentParser is ``prog`` (the program name in usage text).\n",
    "    parser = argparse.ArgumentParser(\n",
    "        description=\"\"\"Implementation of model described in the paper: Asynchronous Methods for Deep Reinforcement Learning for Super Mario Bros\"\"\")\n",
    "    parser.add_argument(\"--world\", type=int, default=1)\n",
    "    parser.add_argument(\"--stage\", type=int, default=1)\n",
    "    parser.add_argument(\"--action_type\", type=str, default=\"complex\")\n",
    "    parser.add_argument(\"--saved_path\", type=str, default=\"trained_models\")\n",
    "    parser.add_argument(\"--output_path\", type=str, default=\"output\")\n",
    "    # An explicit list (never None) keeps argparse away from sys.argv.\n",
    "    return parser.parse_args([] if argv is None else list(argv))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "ef7de497-7fd6-482c-b46a-7a588eca4e52",
   "metadata": {},
   "outputs": [],
   "source": [
    "def test(opt):\n",
    "    \"\"\"Play one Super Mario Bros stage greedily with a saved ActorCritic model.\n",
    "\n",
    "    The checkpoint ``<saved_path>/A3CSuperMarioBros_<world>_<stage>`` is loaded,\n",
    "    the environment is rendered every step, and the function returns once\n",
    "    ``info[\"flag_get\"]`` is truthy. Episodes that end without the flag are\n",
    "    restarted, so the loop has no other exit condition.\n",
    "\n",
    "    Args:\n",
    "        opt: Namespace providing ``world``, ``stage``, ``action_type``,\n",
    "            ``saved_path`` and ``output_path``.\n",
    "    \"\"\"\n",
    "    torch.manual_seed(123)\n",
    "    video_path = \"{}/video_{}_{}.mp4\".format(opt.output_path, opt.world, opt.stage)\n",
    "    checkpoint_path = \"{}/A3CSuperMarioBros_{}_{}\".format(opt.saved_path, opt.world, opt.stage)\n",
    "    # Resolve the device once instead of querying CUDA on every step.\n",
    "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "    env, num_states, num_actions = create_train_env(opt.world, opt.stage, opt.action_type, video_path)\n",
    "    try:\n",
    "        model = ActorCritic(num_states, num_actions)\n",
    "        # map_location places the weights on the chosen device on both the\n",
    "        # CPU-only and the CUDA path, so a single load call covers both.\n",
    "        model.load_state_dict(torch.load(checkpoint_path, map_location=device))\n",
    "        model.to(device)\n",
    "        model.eval()\n",
    "\n",
    "        done = True\n",
    "        while True:\n",
    "            if done:\n",
    "                # Start the episode from the observation reset() returns;\n",
    "                # otherwise the model would be fed the stale terminal frame.\n",
    "                state = torch.from_numpy(env.reset())\n",
    "                # Fresh recurrent state tensors (1 x 512) for the new episode.\n",
    "                h_0 = torch.zeros((1, 512), dtype=torch.float, device=device)\n",
    "                c_0 = torch.zeros((1, 512), dtype=torch.float, device=device)\n",
    "\n",
    "            # Inference only: no autograd graph is built, so the recurrent\n",
    "            # state does not need to be detached between steps.\n",
    "            with torch.no_grad():\n",
    "                logits, _, h_0, c_0 = model(state.to(device), h_0, c_0)\n",
    "                policy = F.softmax(logits, dim=1)\n",
    "            action = torch.argmax(policy).item()  # greedy action as a Python int\n",
    "\n",
    "            state, _, done, info = env.step(action)\n",
    "            state = torch.from_numpy(state)\n",
    "            env.render()\n",
    "            if info[\"flag_get\"]:\n",
    "                print(\"World {} stage {} completed\".format(opt.world, opt.stage))\n",
    "                break\n",
    "    finally:\n",
    "        # Release the environment (render window, video writer) even when the\n",
    "        # rollout raises; assumes env exposes the Gym-style close() method.\n",
    "        env.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd5857ae-c950-4e38-918c-df6f53d4fc79",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the evaluation options with their defaults (world 1, stage 1, complex action set).\n",
    "opt = get_args()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1572d4d-059d-4ff3-812d-8a23959963ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Roll out the trained agent; this returns only once the stage flag is reached.\n",
    "test(opt)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
